import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr
import wandb

import os
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
import sklearn.metrics
from data.dl_getter import get_transform
from data.ds import Non_dataset
from tool.util import get_valid_unit

from data.ds import ood_root
# Every dataset name swept by the OOD analyses in this module.
# 'interp' is synthetic: the analyses build it from the in-distribution loader
# by averaging consecutive batches instead of loading anything from disk.
total_ds = [
    'cifar10', 'svhn', 'dtd', 'iSUN', 'LSUN', 'places365', 'LSUN_R',
    'cifar100', 'interp', 'celeba', 'N', 'U', 'OODomain', 'Constant',
]


def ana_entropy(model, vl_dl, text_log, args=None):
    """Compute the entropy-style confidence score for each dataset in ``total_ds``.

    Every result is passed through ``get_valid_unit``, stored in
    ``text_log['entropy']`` under the dataset name, and printed.

    Args:
        model: model whose outputs ``calculate_entropy`` treats as logits.
        vl_dl: in-distribution loader; reused both for ``args.dataset``
            itself and for the synthetic ``'interp'`` dataset.
        text_log: dict receiving the results (mutated in place).
        args: namespace; ``args.dataset`` is read, so it must not be None.
    """
    entropy_log = {}
    text_log['entropy'] = entropy_log
    for name in total_ds:
        use_in_dist_loader = name in [args.dataset, 'interp']
        loader = vl_dl if use_in_dist_loader else get_dataset(name, args)
        raw = calculate_entropy(model, loader, name == 'interp')
        value = get_valid_unit(raw.item())
        entropy_log[f"{name}"] = value
        print(f"Entropy Score {name} : {value}")


@torch.no_grad()
def calculate_entropy(model, loader, interp=False, device='cuda'):
    """Return the mean of ``-p * log(p)`` over each sample's max softmax probability.

    Note: this is not the Shannon entropy of the full predictive distribution.
    Only the top-class probability ``p`` of every sample enters the sum, which
    is then divided by the number of samples.

    Args:
        model: callable mapping an input batch to logits; the softmax is taken
            over dim 1.
        loader: iterable of ``(x, label)`` batches; labels are ignored.
        interp: if True, score the average of each batch with the previously
            kept batch instead of the raw batches. The first batch is only
            kept, and a batch whose size differs from the kept one is skipped
            without replacing it.
        device: device that inputs are moved to before calling ``model``.

    Returns:
        0-d tensor holding the score.

    Raises:
        ValueError: if nothing was scored. This happens with an empty loader,
            or in ``interp`` mode with fewer than two equal-sized batches.
    """
    scores = []
    previous = None
    for x, _ in loader:
        if interp:
            if previous is None:
                previous = x
                continue
            if x.shape[0] != previous.shape[0]:
                continue
            batch = (x + previous) / 2
            previous = x
        else:
            batch = x
        probs = torch.softmax(model(batch.to(device)), dim=1)
        scores.append(probs.max(dim=1)[0])

    if not scores:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("calculate_entropy: loader produced no batches to score")

    scores = torch.cat(scores)
    # 1e-10 guards log(0) for fully saturated probabilities.
    return -torch.sum(scores * torch.log(scores + 1e-10), dim=-1) / len(scores)


def ana_fpr95(model, vl_dl, text_log, args=None):
    """Compute FPR at 95% TPR for every dataset in ``total_ds`` and both score types.

    Each result comes from ``calculate_fpr95`` with ``vl_dl`` as the
    in-distribution loader. It is passed through ``get_valid_unit``, scaled by
    100, stored in ``text_log['fpr95']`` under ``"<dataset>_<score type>"``,
    and printed.

    Args:
        model: model being evaluated.
        vl_dl: in-distribution loader.
        text_log: dict receiving the results (mutated in place).
        args: namespace; ``args.dataset`` is read, so it must not be None.
    """
    print("FPR95 Evaluation")
    fpr_log = {}
    text_log['fpr95'] = fpr_log
    pairs = [(ds, kind) for ds in total_ds for kind in ("p_x", "p_y|x")]
    for ds, kind in pairs:
        fpr = get_valid_unit(calculate_fpr95(model, vl_dl, ds, kind, args)) * 100
        fpr_log[f"{ds}_{kind}"] = fpr
        print(f"FPR95 Score {args.dataset}(in) ~ {ds}(out) {kind} : {fpr}")


# FPR-at-recall computation adapted from
# https://github.com/deeplearning-wisc/vit-spurious-robustness/blob/32baba08a23712c10c8791b8ebc1dbf6289e9005/evaluation_utils/calMetric.py#L75
def _fpr_at_recall(in_scores, out_scores, recall_level=0.95):
    """Return the OOD false-positive rate at the in-distribution recall closest to ``recall_level``.

    In-distribution samples are the positive class and are expected to score
    higher. Candidate thresholds sit at the end of each run of tied scores, so
    tied samples always fall on the same side of a threshold.

    Args:
        in_scores: array of in-distribution scores (non-empty).
        out_scores: array of OOD scores (non-empty).
        recall_level: target recall (TPR) on ``in_scores``.

    Returns:
        Fraction of ``out_scores`` at or above the selected threshold.
    """
    in_scores = np.ravel(in_scores)
    out_scores = np.ravel(out_scores)
    examples = np.concatenate([out_scores, in_scores])
    is_in = np.zeros(examples.size, dtype=bool)
    is_in[out_scores.size:] = True

    # Highest score first; mergesort is stable.
    order = np.argsort(examples, kind='mergesort')[::-1]
    y_score = examples[order]
    y_true = is_in[order]

    # One candidate threshold at the last index of every run of equal scores.
    threshold_idxs = np.r_[np.where(np.diff(y_score))[0], y_true.size - 1]
    tps = np.cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps  # +1 because indices are zero-based

    recall = tps / tps[-1]
    # Keep thresholds up to the first one that reaches full recall, ordered
    # from full recall downwards, and append the (recall=1, fp=0) end point
    # used by the reference implementation.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)
    recall = np.r_[recall[sl], 1]
    fps = np.r_[fps[sl], 0]

    cutoff = np.argmin(np.abs(recall - recall_level))
    return fps[cutoff] / out_scores.size


@torch.no_grad()
def calculate_fpr95(model, vl_dl, ood_dataset, score_fn_type, args,
                    device='cuda', recall_level=0.95):
    """Return the OOD false-positive rate at 95% in-distribution recall (FPR95).

    Args:
        model: callable mapping an input batch to logits (reduced over dim 1).
        vl_dl: in-distribution loader of ``(x, label)`` batches; labels are ignored.
        ood_dataset: OOD dataset name passed to ``get_dataset``. The special
            name ``"interp"`` instead averages each batch of ``vl_dl`` with the
            previously kept batch, skipping batches whose size differs.
        score_fn_type: ``"p_x"`` (logsumexp of the logits) or ``"p_y|x"``
            (maximum softmax probability). Higher should mean "more
            in-distribution".
        args: forwarded to ``get_dataset``.
        device: device that inputs are moved to before calling ``model``.
        recall_level: target in-distribution recall; 0.95 gives FPR95.

    Returns:
        The false-positive rate as a float in [0, 1].

    Raises:
        ValueError: for an unknown ``score_fn_type``, or when either the
            in-distribution or the OOD side produced no scores.
    """
    if score_fn_type == "p_x":
        def score_fn(x):
            return model(x).cpu().logsumexp(1)
    elif score_fn_type == "p_y|x":
        def score_fn(x):
            return torch.softmax(model(x), dim=1).max(1)[0].cpu()
    else:
        raise ValueError(
            f"unknown score_fn_type {score_fn_type!r}; expected 'p_x' or 'p_y|x'")

    # get_dataset is only consulted for real OOD datasets.
    dload_fake = vl_dl if ood_dataset == "interp" else get_dataset(ood_dataset, args)

    in_scores = [score_fn(x.to(device)) for x, _ in vl_dl]

    out_scores = []
    if ood_dataset == "interp":
        previous = None
        for x, _ in dload_fake:
            if previous is not None:
                if x.shape[0] != previous.shape[0]:
                    continue
                out_scores.append(score_fn(((x + previous) / 2).to(device)))
            previous = x
    else:
        out_scores = [score_fn(x.to(device)) for x, _ in dload_fake]

    if not in_scores or not out_scores:
        raise ValueError(
            f"calculate_fpr95: no scores for {'in-distribution' if not in_scores else ood_dataset!r} data")

    return _fpr_at_recall(
        torch.cat(in_scores).numpy(), torch.cat(out_scores).numpy(), recall_level)


def ana_oodauc(model, vl_dl, text_log, args):
    """Compute OOD AUROC and AUPR for every dataset in ``total_ds``.

    For each dataset and score type (``"p_x"``, ``"p_y|x"``), ``OODAUC`` is
    run with ``vl_dl`` as the in-distribution loader. AUROC and AUPR are
    stored as percentages rounded to 2 decimals in ``text_log['ood_auc']`` and
    ``text_log['ood_aupr']``, keyed by ``"<dataset>_<score type>"``. The raw
    AUROCs are logged to wandb under ``ood/auc_<key>`` without committing
    the step.

    Args:
        model: model being evaluated.
        vl_dl: in-distribution loader.
        text_log: dict receiving the results (mutated in place).
        args: namespace; ``args.dataset`` is read and ``args`` is forwarded.

    Returns:
        Sum of the raw AUROC values over all (dataset, score type) pairs.
    """
    scores = {}
    print("OODAUC Evaluation")

    auc_log = {}
    aupr_log = {}
    text_log['ood_auc'] = auc_log
    text_log['ood_aupr'] = aupr_log
    for ood_dataset in total_ds:
        for score_fn_type in ["p_x", "p_y|x"]:
            score, aupr = OODAUC(
                model, vl_dl, ood_dataset, sigma=0.,
                score_fn_type=score_fn_type, args=args)
            key = f"{ood_dataset}_{score_fn_type}"
            scores[key] = score
            print(f"AUC Score {args.dataset}(in) ~ {ood_dataset}(out) ~ {score_fn_type} : {score}")
            print(f"AUPR Score {args.dataset}(in) ~ {ood_dataset}(out) ~ {score_fn_type} : {aupr}")
            # Scale before rounding: round(x, 4) * 100 leaves float artifacts
            # such as 93.21000000000001 in the logged percentages.
            auc_log[key] = round(score * 100, 2)
            aupr_log[key] = round(aupr * 100, 2)

    wandb.log({f"ood/auc_{k}": v for k, v in scores.items()}, commit=False)
    return sum(scores.values())


def OODAUC(
        model,
        dload_real,
        ood_dataset,
        sigma=0,
        score_fn_type='px',
        args=None,
        num_workers=0,
        device='cuda'):
    """Return ``(AUROC, AUPR)`` for separating ``dload_real`` from ``ood_dataset``.

    In-distribution samples are labelled 1 and OOD samples 0, so higher scores
    should mean "more in-distribution".

    Args:
        model: callable mapping an input batch to logits.
        dload_real: in-distribution loader of ``(x, label)`` batches.
        ood_dataset: OOD dataset name passed to ``get_dataset``. The special
            name ``"interp"`` instead averages each batch of ``dload_real``
            with the previously kept batch (skipping batches whose size
            differs) and adds ``sigma``-scaled Gaussian noise.
        sigma: noise scale for the ``"interp"`` mixes; unused otherwise.
        score_fn_type: ``"p_x"`` (logsumexp of logits) or ``"p_y|x"`` (max
            softmax probability). Any other value scores by the negative L2
            norm of the input gradient of the summed model output.
            NOTE(review): the default ``'px'`` does not match ``"p_x"`` and so
            selects the gradient-norm score; callers in this file pass
            ``score_fn_type`` explicitly.
        args: forwarded to ``get_dataset``.
        num_workers: unused; kept for call compatibility.
        device: device that inputs are moved to.

    Returns:
        Tuple ``(roc_auc_score, average_precision_score)``.
    """

    def grad_norm(x):
        # Fresh leaf sharing x's storage, so gradients w.r.t. the input exist.
        x_k = x.detach().requires_grad_(True)
        f_prime = torch.autograd.grad(model(x_k).sum(), [x_k])[0]
        return f_prime.view(x.size(0), -1).norm(p=2, dim=1)

    def score_fn(x):
        if score_fn_type == "p_x":
            # No autograd graph is needed for the logit-based scores.
            with torch.no_grad():
                return model(x).cpu().logsumexp(1)
        if score_fn_type == "p_y|x":
            with torch.no_grad():
                return torch.softmax(model(x), dim=1).max(1)[0].cpu()
        # The gradient-based score needs autograd, so no no_grad() here.
        return -grad_norm(x).detach().cpu()

    dload_fake = get_dataset(ood_dataset, args) if ood_dataset != "interp" else dload_real

    real_scores = [score_fn(x.to(device)).numpy() for x, _ in dload_real]

    fake_scores = []
    if ood_dataset == "interp":
        last_batch = None
        for i, (x, _) in enumerate(dload_fake):
            x = x.to(device)
            if i > 0:
                if x.shape[0] != last_batch.shape[0]:
                    continue
                x_mix = (x + last_batch) / 2 + sigma * torch.randn_like(x)
                fake_scores.append(score_fn(x_mix).numpy())
            last_batch = x
    else:
        fake_scores = [score_fn(x.to(device)).numpy() for x, _ in dload_fake]

    real_scores = np.concatenate(real_scores)
    fake_scores = np.concatenate(fake_scores)
    scores = np.concatenate([real_scores, fake_scores])
    labels = np.concatenate([np.ones_like(real_scores), np.zeros_like(fake_scores)])
    score = sklearn.metrics.roc_auc_score(labels, scores)
    aupr = sklearn.metrics.average_precision_score(labels, scores)
    return score, aupr


def ana_logp_hist(model, vl_dl, save_dir="./logp_hist", args=None):
    """Plot score histograms of ``args.dataset`` against other datasets.

    The OOD side comes from ``['cifar10', 'svhn', 'cifar100', 'celeba']`` with
    ``args.dataset`` removed; only the first three remaining names are used.
    For every pair and each score type (``"p_x"``, ``"p_y|x"``), ``logp_hist``
    writes ``<save_dir>/<in>~<out>~<score type>.png``.

    Args:
        model: model being evaluated.
        vl_dl: loader for ``args.dataset``.
        save_dir: output directory; created if it does not exist.
        args: namespace; ``args.dataset`` is read, so it must not be None.
    """
    hist_ds = ['cifar10', 'svhn', 'cifar100', 'celeba']
    ood_datasets = [ds for ds in hist_ds if ds != args.dataset]
    datasets_list = [[args.dataset, ood] for ood in ood_datasets[:3]]

    # plt.savefig does not create parent directories.
    os.makedirs(save_dir, exist_ok=True)

    for datasets in datasets_list:
        for score_fn_type in ["p_x", "p_y|x"]:
            logp_hist(
                model, vl_dl, n_steps=20,
                score_fn_type=score_fn_type,
                sigma=0., datasets_list=datasets, args=args,
                save_dir=f"{save_dir}/{datasets[0]}~{datasets[1]}~{score_fn_type}")


def logp_hist(model, loader, n_steps, score_fn_type,
              sigma, datasets_list, args, save_dir, device='cuda'):
    """Plot overlaid score histograms for ``datasets_list`` and log them to wandb.

    The first name in ``datasets_list`` is read from ``loader``; every other
    name is loaded with ``get_dataset``. The figure is saved to
    ``save_dir + ".png"`` and logged to wandb under
    ``logp_hist/<datasets_list>~<score_fn_type>``.

    Args:
        model: callable mapping an input batch to logits.
        loader: loader for ``datasets_list[0]``.
        n_steps: unused; kept for call compatibility.
        score_fn_type: ``"p_x"`` (logsumexp of logits), ``"p_y|x"`` (max
            softmax probability); anything else uses the max logit.
        sigma: unused; kept for call compatibility.
        datasets_list: dataset names to plot.
        args: forwarded to ``get_dataset``.
        save_dir: output path without the ``.png`` suffix.
        device: device that inputs are moved to.
    """
    print(f"Logp hist - {score_fn_type} - {datasets_list}")

    def score_fn(x):
        logits = model(x)
        if score_fn_type == "p_x":
            return logits.cpu().logsumexp(1)
        if score_fn_type == "p_y|x":
            return torch.softmax(logits, dim=1).max(1)[0].cpu()
        return logits.max(1)[0].cpu()

    score_dict = {}
    # Scores are only plotted, so skip building autograd graphs.
    with torch.no_grad():
        for idx, dataset_name in enumerate(datasets_list):
            dataloader = loader if idx == 0 else get_dataset(dataset_name, args)
            this_scores = []
            for x, _ in dataloader:
                this_scores.extend(score_fn(x.to(device)).numpy())
            score_dict[dataset_name] = this_scores

    for name, scores in score_dict.items():
        plt.hist(scores, label=name, bins=100, density=True, alpha=.5)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14, loc='upper left')
    plt.savefig(save_dir + ".png")
    plt.clf()
    wandb.log({
        f'logp_hist/{datasets_list}~{score_fn_type}': wandb.Image(save_dir + ".png")
    })


def get_dataset(ood_dataset, args):
    """Return a ``DataLoader`` (batch size 100, no shuffling) for an OOD dataset.

    Args:
        ood_dataset: one of ``cifar10``, ``svhn``, ``cifar100``, ``celeba``,
            ``N``, ``U``, ``OODomain``, ``Constant``, ``dtd``, ``iSUN``,
            ``LSUN``, ``LSUN_R``, ``places365``.
        args: namespace. When ``args.crl`` is truthy, every image dataset is
            resized to 224x224 before ``get_transform()``. Otherwise
            ``celeba``, ``dtd``, ``LSUN`` and ``places365`` are resized to
            32x32, and the other image datasets use ``get_transform()`` as is.

    Raises:
        ValueError: if ``ood_dataset`` is not one of the names above.
    """
    # name -> builder(plain_transform, resized_transform) -> dataset.
    # Builders are lambdas so nothing is constructed until the name is validated.
    builders = {
        "cifar10": lambda plain, resized: tv.datasets.CIFAR10(
            root="~/data", transform=plain, download=False, train=False),
        "svhn": lambda plain, resized: tv.datasets.SVHN(
            root="~/data", transform=plain, download=False, split="test"),
        "cifar100": lambda plain, resized: tv.datasets.CIFAR100(
            root="~/data", transform=plain, download=False, train=False),
        "celeba": lambda plain, resized: tv.datasets.CelebA(
            root="~/data", download=False, split="test", transform=resized),
        "dtd": lambda plain, resized: tv.datasets.ImageFolder(
            root=ood_root['dtd'], transform=resized),
        "iSUN": lambda plain, resized: tv.datasets.ImageFolder(
            root=ood_root['iSUN'], transform=plain),
        "LSUN": lambda plain, resized: tv.datasets.ImageFolder(
            root=ood_root['LSUN'], transform=resized),
        "LSUN_R": lambda plain, resized: tv.datasets.ImageFolder(
            root=ood_root['LSUN_R'], transform=plain),
        "places365": lambda plain, resized: tv.datasets.ImageFolder(
            root=ood_root['places365'], transform=resized),
    }
    for kind in ("N", "U", "OODomain", "Constant"):
        # Bind `kind` as a default to avoid the late-binding closure pitfall.
        builders[kind] = lambda plain, resized, kind=kind: Non_dataset(type=kind)

    if ood_dataset not in builders:
        raise ValueError(
            f"Unknown OOD dataset {ood_dataset!r}; expected one of {sorted(builders)}")

    if args.crl:
        plain = tr.Compose([tr.Resize((224, 224)), get_transform()])
        resized = tr.Compose([tr.Resize((224, 224)), get_transform()])
    else:
        plain = get_transform()
        resized = tr.Compose([tr.Resize((32, 32)), get_transform()])

    dset_fake = builders[ood_dataset](plain, resized)
    return DataLoader(dset_fake, batch_size=100, shuffle=False, drop_last=False)